{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Import Data"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import joblib\r\n",
    "import pandas as pd\r\n",
    "from sklearn.metrics import accuracy_score\r\n",
    "from sklearn.model_selection import train_test_split\r\n",
    "from sklearn.tree import DecisionTreeClassifier\r\n",
    "\r\n",
    "# NOTE: joblib is imported directly. `sklearn.externals.joblib` was\r\n",
    "# deprecated in scikit-learn 0.21 and scheduled for removal in 0.23, so\r\n",
    "# `from sklearn.externals import joblib` breaks on newer releases.\r\n",
    "\r\n",
    "# Load the dataset (columns: age, gender, genre) and display it.\r\n",
    "music_data = pd.read_csv('music.csv')\r\n",
    "music_data\r\n"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "D:\\Anaconda3\\lib\\site-packages\\sklearn\\externals\\joblib\\__init__.py:15: DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n",
      "  warnings.warn(msg, category=DeprecationWarning)\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>genre</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>1</td>\n",
       "      <td>HipHop</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>23</td>\n",
       "      <td>1</td>\n",
       "      <td>HipHop</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>25</td>\n",
       "      <td>1</td>\n",
       "      <td>Jazz</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>26</td>\n",
       "      <td>1</td>\n",
       "      <td>Jazz</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>29</td>\n",
       "      <td>1</td>\n",
       "      <td>Jazz</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>30</td>\n",
       "      <td>1</td>\n",
       "      <td>Jazz</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>31</td>\n",
       "      <td>1</td>\n",
       "      <td>Classical</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>33</td>\n",
       "      <td>1</td>\n",
       "      <td>Classical</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>37</td>\n",
       "      <td>1</td>\n",
       "      <td>Classical</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>Dance</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>21</td>\n",
       "      <td>0</td>\n",
       "      <td>Dance</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>25</td>\n",
       "      <td>0</td>\n",
       "      <td>Dance</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>26</td>\n",
       "      <td>0</td>\n",
       "      <td>Acoustic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>27</td>\n",
       "      <td>0</td>\n",
       "      <td>Acoustic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>30</td>\n",
       "      <td>0</td>\n",
       "      <td>Acoustic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>31</td>\n",
       "      <td>0</td>\n",
       "      <td>Classical</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>34</td>\n",
       "      <td>0</td>\n",
       "      <td>Classical</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>35</td>\n",
       "      <td>0</td>\n",
       "      <td>Classical</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    age  gender      genre\n",
       "0    20       1     HipHop\n",
       "1    23       1     HipHop\n",
       "2    25       1       Jazz\n",
       "3    26       1       Jazz\n",
       "4    29       1       Jazz\n",
       "5    30       1       Jazz\n",
       "6    31       1  Classical\n",
       "7    33       1  Classical\n",
       "8    37       1  Classical\n",
       "9    20       0      Dance\n",
       "10   21       0      Dance\n",
       "11   25       0      Dance\n",
       "12   26       0   Acoustic\n",
       "13   27       0   Acoustic\n",
       "14   30       0   Acoustic\n",
       "15   31       0  Classical\n",
       "16   34       0  Classical\n",
       "17   35       0  Classical"
      ]
     },
     "metadata": {},
     "execution_count": 1
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Prepare Data"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "# Separate the label from the inputs: `y` is the genre column to predict,\r\n",
    "# `X` is every remaining column (age and gender).\r\n",
    "y = music_data['genre']\r\n",
    "X = music_data.drop(columns=['genre'])\r\n",
    "\r\n",
    "print(X, '\\n')\r\n",
    "print(y)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "    age  gender\n",
      "0    20       1\n",
      "1    23       1\n",
      "2    25       1\n",
      "3    26       1\n",
      "4    29       1\n",
      "5    30       1\n",
      "6    31       1\n",
      "7    33       1\n",
      "8    37       1\n",
      "9    20       0\n",
      "10   21       0\n",
      "11   25       0\n",
      "12   26       0\n",
      "13   27       0\n",
      "14   30       0\n",
      "15   31       0\n",
      "16   34       0\n",
      "17   35       0 \n",
      "\n",
      "0        HipHop\n",
      "1        HipHop\n",
      "2          Jazz\n",
      "3          Jazz\n",
      "4          Jazz\n",
      "5          Jazz\n",
      "6     Classical\n",
      "7     Classical\n",
      "8     Classical\n",
      "9         Dance\n",
      "10        Dance\n",
      "11        Dance\n",
      "12     Acoustic\n",
      "13     Acoustic\n",
      "14     Acoustic\n",
      "15    Classical\n",
      "16    Classical\n",
      "17    Classical\n",
      "Name: genre, dtype: object\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Learning and Predicting\n",
    "Decision Tree -> sklearn"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "# Train on the full dataset; `fit` returns the estimator itself, so the\r\n",
    "# construction and training can be chained.\r\n",
    "# Tip: inspect `music_data` first if the results look off.\r\n",
    "model = DecisionTreeClassifier().fit(X, y)\r\n",
    "\r\n",
    "# Rows are [age, gender]: a 21 year old male and a 22 year old female\r\n",
    "predictions = model.predict([[21, 1], [22, 0]])\r\n",
    "predictions"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array(['HipHop', 'Dance'], dtype=object)"
      ]
     },
     "metadata": {},
     "execution_count": 3
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Calculating the Accuracy\n",
    "Use 70-80 percent of the data for training and the rest for testing -> sklearn.model_selection"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "# Fix the seed so the split (and therefore the score) is reproducible.\n",
    "# Without it every run draws a different test set: with only 18 rows the\n",
    "# accuracy swung between 0.25 and 1.0 from one run to the next.\n",
    "RANDOM_STATE = 42\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.2, random_state=RANDOM_STATE\n",
    ")\n",
    "model = DecisionTreeClassifier(random_state=RANDOM_STATE)\n",
    "model.fit(X_train, y_train)\n",
    "predictions = model.predict(X_test)\n",
    "\n",
    "# sklearn.metrics: fraction of test rows whose genre was predicted correctly\n",
    "score = accuracy_score(y_test, predictions)\n",
    "print(score)\n",
    "# The more data we give the model, and the cleaner that data is, the\n",
    "# better the result we can expect."
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "1.0\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Persisting Models\n",
    "Saving a trained model to disk and loading it back later, so it only needs retraining once in a while -> joblib"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "# Train once on all available rows...\n",
    "model = DecisionTreeClassifier().fit(X, y)\n",
    "\n",
    "# ...and save the fitted estimator to disk so it can be reloaded later\n",
    "# without retraining. `dump` returns the list of files it wrote.\n",
    "joblib.dump(model, 'music-recommender.joblib')"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "['music-recommender.joblib']"
      ]
     },
     "metadata": {},
     "execution_count": 5
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "# Load the estimator persisted by the previous cell.\n",
    "# joblib files are pickle-based, so only load files from a trusted source.\n",
    "trained_model = joblib.load('music-recommender.joblib')\n",
    "\n",
    "# Predict with the *loaded* model (not the in-memory `model`), so this cell\n",
    "# actually verifies that the saved file round-trips correctly.\n",
    "predictions = trained_model.predict([[21, 1]])\n",
    "predictions"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array(['HipHop'], dtype=object)"
      ]
     },
     "metadata": {},
     "execution_count": 6
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Visualizing a Decision Tree"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "import pandas as pd\n",
    "from sklearn import tree\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# This cell is self-contained: it reloads the data and refits the model so\n",
    "# the exported tree does not depend on state left over from earlier cells.\n",
    "music_data = pd.read_csv('music.csv')\n",
    "X = music_data.drop(columns=['genre'])\n",
    "y = music_data['genre']\n",
    "\n",
    "model = DecisionTreeClassifier()\n",
    "model.fit(X, y)\n",
    "\n",
    "# Write the fitted tree in Graphviz DOT format. Feature and class names are\n",
    "# taken from the training data / fitted model instead of being hard-coded,\n",
    "# so the labels cannot drift out of sync with the columns or genres.\n",
    "tree.export_graphviz(\n",
    "    model,\n",
    "    out_file='music-recommender.dot',\n",
    "    feature_names=list(X.columns),\n",
    "    class_names=[str(label) for label in model.classes_],\n",
    "    label='all',\n",
    "    rounded=True,\n",
    "    filled=True,\n",
    ")"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a8f61be024eba58adef938c9aa1e29e02cb3dece83a5348b1a2dafd16a070453"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.4 64-bit ('base': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}